
# The experiments in our paper can be reproduced with "python main.py --flip_rate ...". The default flip rate is 0.25 (25%).



import argparse
import numpy as np
import torch
import pandas as pd

from torchvision import datasets
from torch import nn, optim, autograd
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D

from Data_generating import data_generator,data_high_generator,high_data_loader,data_loader
from additional_functions import mean_nll,mean_nll2,mean_accuracy,prob_sum,softmax,Condi_MI,mean_nll2_forIRM


parser = argparse.ArgumentParser(description='Colored MNIST')
parser.add_argument('--hidden_dim', type=int, default=440)
parser.add_argument('--l2_regularizer_weight', type=float, default=0.002)
parser.add_argument('--lr', type=float, default=0.0004)
parser.add_argument('--n_restarts', type=int, default=5)
parser.add_argument('--steps', type=int, default=501)
parser.add_argument('--high_env_number', type=int, default=5)
# Parameters of the high-level environments, e.g. `--high_env 0.1 0.3 0.5 0.7 0.9`.
# (`type=np.array` would turn a command-line string into a 0-d *string* array,
# which cannot be iterated over, so parse a list of floats instead.)
parser.add_argument('--high_env', type=float, nargs='+',
                    default=np.array([0.1, 0.3, 0.5, 0.7, 0.9]))

parser.add_argument('--flip_rate', type=float, default=0.25)
flags = parser.parse_args()
# Normalise to an ndarray so downstream code (and the output file name) sees the
# same type whether or not --high_env was given on the command line.
flags.high_env = np.asarray(flags.high_env, dtype=float)


def pretty_print(*values):
    """Print ``values`` on a single line as left-justified 13-character columns.

    Strings are printed as-is; any other value is rendered with
    ``np.array2string`` using 5 fixed decimals. Columns are separated by
    three spaces.
    """
    width = 13
    cells = []
    for value in values:
        if isinstance(value, str):
            text = value
        else:
            text = np.array2string(value, precision=5, floatmode='fixed')
        cells.append(text.ljust(width))
    print("   ".join(cells))

def condi_prob(x):
    """Renormalise columns 1 and 2 of ``x`` so that they sum to one per row.

    Returns a ``(rows, 2)`` tensor ``[x1 / (x1 + x2), x2 / (x1 + x2)]``.
    Column 0 and any columns after column 2 are ignored. There is no guard
    against ``x1 + x2 == 0``.
    """
    first = x[:, 1]
    second = x[:, 2]
    total = first + second
    return torch.stack([first / total, second / total]).T

class MLP(nn.Module):
    """MLP with a shared three-layer trunk and two output heads.

    Inputs are flattened to ``2*14*14`` features. Both heads reuse the same
    first three linear layers (ReLU after the first two, none after the third):

    * ``_main2`` ends in a 1-unit linear layer (used by ``forward_all``).
      Index 5 of ``_main2`` is that layer; the IRM penalty in this script
      indexes it, so the layer order must stay fixed.
    * ``_main3`` ends in a 10-unit linear layer (used by the other forwards).

    All linear layers get Xavier-uniform weights and zero biases.
    """

    def __init__(self):
        super().__init__()
        in_features = 2 * 14 * 14
        width = flags.hidden_dim
        # Layers are created (and initialised) in the same order as before so
        # that RNG consumption, and hence initial weights, are unchanged.
        trunk = [
            nn.Linear(in_features, width),
            nn.Linear(width, width),
            nn.Linear(width, width),
        ]
        single_head = nn.Linear(width, 1)
        ten_way_head = nn.Linear(width, 10)
        for layer in (*trunk, single_head, ten_way_head):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)
        first, second, third = trunk
        # `_main2` is registered before `_main3` to keep parameters() ordering.
        self._main2 = nn.Sequential(first, nn.ReLU(True), second, nn.ReLU(True), third, single_head)
        self._main3 = nn.Sequential(first, nn.ReLU(True), second, nn.ReLU(True), third, ten_way_head)

    @staticmethod
    def _flatten(batch):
        # (N, ...) -> (N, 2*14*14); assumes each sample holds exactly 2*14*14 values.
        return batch.view(batch.shape[0], 2 * 14 * 14)

    def forward_all(self, input):
        """Return the raw output of the 1-unit head, shape ``(N, 1)``."""
        return self._main2(self._flatten(input))

    def forward_low2(self, input):
        """Return ``softmax`` of the 10-unit head's output."""
        return softmax(self._main3(self._flatten(input)))

    def forward_lower5(self, input):
        """Return the first 5 raw outputs of the 10-unit head."""
        return self._main3(self._flatten(input))[:, :5]

    def forward_upper5(self, input):
        """Return the last 5 raw outputs of the 10-unit head."""
        return self._main3(self._flatten(input))[:, 5:]

    def forward_z(self, input):
        """Return ``(N, 2)``: softmax mass on outputs 0-4 and on outputs 5-9."""
        probs = softmax(self._main3(self._flatten(input)))

        def column_sum(lo, hi):
            # Left-to-right accumulation p[lo] + p[lo+1] + ... + p[hi-1].
            return sum((probs[:, k] for k in range(lo + 1, hi)), probs[:, lo])

        return torch.stack([column_sum(0, 5), column_sum(5, 10)]).T

def CV(images, labels, split, out_number):
    """Split a dataset into a training part and one held-out cross-validation fold.

    The held-out fold consists of the samples selected by the slice
    ``[out_number::split]``; every other sample goes to the training part.

    Args:
        images: tensor whose first dimension indexes samples.
        labels: tensor with the same first dimension as ``images``
            (any rank, including 1-D).
        split: number of folds, i.e. the stride between held-out samples.
        out_number: index of the held-out fold (normally ``0 <= out_number < split``).

    Returns:
        ``[training_part, held_out_fold]``, each a dict with keys
        ``'images'`` and ``'labels'``.
    """
    out_images = images[out_number::split]
    out_labels = labels[out_number::split]
    # Build the training mask from the very same slice as the held-out fold so
    # the two parts are always exact complements. A test such as
    # `i % split == out_number` disagrees with the slice whenever out_number is
    # outside [0, split), duplicating or dropping samples.
    keep = torch.ones(len(images), dtype=torch.bool, device=images.device)
    keep[out_number::split] = False
    leave_images = images[keep]
    leave_labels = labels[keep.to(labels.device)]
    return [{'images': leave_images, 'labels': leave_labels},
            {'images': out_images, 'labels': out_labels}]


def hCV(images, labels, split, out_number):
    """Cross-validation split used for the high-level environments.

    Same contract as ``CV``: the samples selected by ``[out_number::split]``
    form the held-out fold, and all remaining samples form the training part.

    Args:
        images: tensor whose first dimension indexes samples.
        labels: tensor with the same first dimension as ``images``
            (any rank, including 1-D).
        split: number of folds, i.e. the stride between held-out samples.
        out_number: index of the held-out fold (normally ``0 <= out_number < split``).

    Returns:
        ``[training_part, held_out_fold]``, each a dict with keys
        ``'images'`` and ``'labels'``.
    """
    out_images = images[out_number::split]
    out_labels = labels[out_number::split]
    # Derive the training mask from the same slice as the held-out fold so the
    # two parts are exact complements for any out_number (a modulo test would
    # disagree with the slice when out_number is outside [0, split)).
    keep = torch.ones(len(images), dtype=torch.bool, device=images.device)
    keep[out_number::split] = False
    leave_images = images[keep]
    leave_labels = labels[keep.to(labels.device)]
    return [{'images': leave_images, 'labels': leave_labels},
            {'images': out_images, 'labels': out_labels}]





# ---------------------------------------------------------------------------
# Experiment driver (runs at import time, like the rest of this script).
#
# For each restart, each penalty warm-up length `iters` and each log10 penalty
# weight, the script
#   1. scores the setting by N_FOLDS-fold cross-validation
#      (CV-I: max of the held-out losses; CV-II: the same max with the
#      substituted up/down bias terms added), then
#   2. trains a final model on all training data and records its accuracy on
#      the two test environments envs[1] and envs[2].
# One row per setting is collected for every restart and written to a CSV.
# ---------------------------------------------------------------------------

N_FOLDS = 10
iters_array = np.array([0, 100, 200, 300])        # steps before the full penalty applies
penalty_weight_array = np.linspace(0.0, 6.0, 7)   # log10 of the IRM penalty multiplier
n_high = len(flags.high_env)                      # number of high-level environments built
last_step = flags.steps - 1

# results_per_restart[r] holds one row per (iters, penalty_weight) setting.
# Restart 0 rows: [iters, penalty_weight, CV-II, CV-I, test1 acc, test2 acc];
# later restarts omit the two grid values because the grid is identical.
results_per_restart = [[] for _ in range(flags.n_restarts)]


def _irm_penalty(model, logits, y):
    """Return the IRM penalty for one environment.

    The penalty is the squared norm of the gradient of ``mean_nll2_forIRM``
    w.r.t. the weight of ``model._main2[5]`` (the 1-unit output layer).
    ``create_graph=True`` lets the penalty itself be back-propagated.
    """
    loss = mean_nll2_forIRM(logits, y)
    grad = autograd.grad(loss, model._main2[5].parameters(), create_graph=True)[0]
    return torch.sum(grad ** 2)


def _training_loss(model, nll, penalty_value, step, iters, log_penalty_weight):
    """Return ``(loss, multiplier)`` for one optimisation step.

    ``loss = nll + l2 * sum_w ||w||^2 + multiplier * penalty``, where the
    multiplier is ``10 ** log_penalty_weight`` once ``step >= iters`` and 1.0
    before that. When the multiplier exceeds 1, the whole loss is divided by it
    to keep gradient magnitudes in a reasonable range. Both the CV loop and the
    final training loop use this helper, so they optimise the same objective.
    """
    weight_norm = torch.tensor(0.).cuda()
    for w in model.parameters():
        weight_norm += w.norm().pow(2)  # squared L2 norm of every parameter tensor
    loss = nll.clone()
    loss += flags.l2_regularizer_weight * weight_norm
    multiplier = 10 ** log_penalty_weight if step >= iters else 1.0
    loss += multiplier * penalty_value
    if multiplier > 1.0:
        loss /= multiplier
    return loss, multiplier


def _cv_loss_vectors(envs_out, envs_high_out, high_envs):
    """Collect the held-out losses over which the CV scores take the max.

    Returns ``(plain, substituted)``. Both vectors start with the held-out
    low-level loss of ``envs_out[0]``. ``plain`` then lists the held-out
    high-level losses. ``substituted`` lists the same losses plus the up/down
    bias terms of ``envs_out[0]``, weighted by each high environment's label
    ratios.
    """
    low = envs_out[0]
    plain = torch.cat([
        low['nll'].view(-1),
        torch.cat([envs_high_out[i]['nll'].view(-1)
                   for i in range(flags.high_env_number)]),
    ]).view(-1)
    substituted = torch.cat([
        low['nll'].view(-1),
        torch.cat([envs_high_out[i]['nll'].view(-1)
                   + high_envs[i]['up_ratio'] * low['up_bias']
                   + high_envs[i]['down_ratio'] * low['down_bias']
                   for i in range(flags.high_env_number)]),
    ]).view(-1)
    return plain, substituted


for restart in range(flags.n_restarts):
    print('Restart:', restart)

    mnist = datasets.MNIST('~/datasets/mnist', train=True, download=True)
    mnist_train = (mnist.data[:50000], mnist.targets[:50000])
    mnist_val = (mnist.data[50000:], mnist.targets[50000:])

    # Shuffle images and labels with the same permutation: replaying the RNG
    # state makes the second shuffle identical to the first. `.numpy()` shares
    # memory with the tensor, so the tensors are shuffled in place.
    rng_state = np.random.get_state()
    np.random.shuffle(mnist_train[0].numpy())
    np.random.set_state(rng_state)
    np.random.shuffle(mnist_train[1].numpy())

    # Even-indexed training digits go to data_generator, odd-indexed ones to
    # data_high_generator.
    mnist_train_image = mnist_train[0][::2]
    mnist_train_label = mnist_train[1][::2]
    high_mnist_train_image = mnist_train[0][1::2]
    high_mnist_train_label = mnist_train[1][1::2]

    envs = [
        data_generator(mnist_train_image, mnist_train_label, 0.1, flags.flip_rate),
        data_generator(mnist_val[0][::2], mnist_val[1][::2], 0.1, flags.flip_rate),
        data_generator(mnist_val[0][1::2], mnist_val[1][1::2], 0.9, flags.flip_rate),
    ]

    # One disjoint slice of the high-level data per entry of flags.high_env;
    # the stride equals the number of environments, so no sample is reused.
    high_envs = [
        data_high_generator(high_mnist_train_image[i::n_high],
                            high_mnist_train_label[i::n_high],
                            env_param, flags.flip_rate)
        for i, env_param in enumerate(flags.high_env)
    ]

    for env in high_envs:
        flat_labels = env['labels'].view(-1)
        n_labels = env['labels'].shape[0]
        upper_ratio = env['labels'][flat_labels == 1, :].shape[0] / n_labels
        print('upper_ratio:', upper_ratio)
        env['up_ratio'] = upper_ratio
        down_ratio = env['labels'][flat_labels == 0, :].shape[0] / n_labels
        print('down_ratio:', down_ratio)
        env['down_ratio'] = down_ratio

    for iters in iters_array:
        for penalty_weight in penalty_weight_array:
            print('iters:', iters)
            print('penalty_weight:', penalty_weight)

            print('Flags:')
            for k, v in sorted(vars(flags).items()):
                print("\t{}: {}".format(k, v))

            CV_store = []
            sbsCV_store = []

            print('Starting CV under (iters,weight_grad)=({},{})'.format(iters, penalty_weight))

            for CV_step in range(N_FOLDS):
                print('CV_step:', CV_step)
                # Each split returns [training part, held-out fold].
                low_splits = [CV(env['images'], env['labels'], N_FOLDS, CV_step)
                              for env in envs[:2]]
                high_splits = [hCV(high_envs[i]['images'], high_envs[i]['labels'], N_FOLDS, CV_step)
                               for i in range(flags.high_env_number)]
                envs_leave = [split[0] for split in low_splits]
                envs_out = [split[1] for split in low_splits]
                envs_high_leave = [split[0] for split in high_splits]
                envs_high_out = [split[1] for split in high_splits]

                mlp_CV = MLP().cuda()
                optimizer = optim.Adam(mlp_CV.parameters(), lr=flags.lr)

                pretty_print('step', 'train nll', 'train acc', 'train penalty', 'w*train penalty')

                for step in range(flags.steps):
                    for env in envs_leave:
                        logits_low = mlp_CV.forward_low2(env['images'])
                        env['nll'] = mean_nll2(logits_low, env['labels'])
                        env['acc'] = mean_accuracy(logits_low, env['labels'])
                    for env in envs_high_leave:
                        logits = mlp_CV.forward_all(env['images'])
                        env['penalty'] = _irm_penalty(mlp_CV, logits, env['labels'])
                    train_nll = envs_leave[0]['nll']
                    train_acc = envs_leave[0]['acc']
                    train_penalty = torch.stack([env['penalty'] for env in envs_high_leave]).mean()

                    loss, penalty_weights = _training_loss(
                        mlp_CV, train_nll, train_penalty, step, iters, penalty_weight)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # Held-out losses after the update; the values from the
                    # last step feed the CV scores below.
                    for env in envs_out:
                        flat_labels = env['labels'].view(-1)
                        upper = flat_labels > 4
                        lower = flat_labels < 5
                        logits_low = mlp_CV.forward_low2(env['images'])
                        env['nll'] = mean_nll2(logits_low, env['labels'])
                        logits_up = mlp_CV.forward_upper5(env['images'][upper, :, :, :])
                        env['up_bias'] = mean_nll2(logits_up, env['labels'][upper, :] - 5)
                        logits_down = mlp_CV.forward_lower5(env['images'][lower, :, :, :])
                        env['down_bias'] = mean_nll2(logits_down, env['labels'][lower, :])
                    for env in envs_high_out:
                        logits_low = mlp_CV.forward_z(env['images'])
                        env['nll'] = mean_nll(torch.log(logits_low), env['labels'])

                    if step % 100 == 0:
                        pretty_print(
                            np.int32(step),
                            train_nll.detach().cpu().numpy(),
                            train_acc.detach().cpu().numpy(),
                            train_penalty.detach().cpu().numpy(),
                            (train_penalty * penalty_weights).detach().cpu().numpy())

                    if step == last_step:
                        plain, substituted = _cv_loss_vectors(envs_out, envs_high_out, high_envs)
                        print(plain)
                        print(torch.max(plain))
                        print(torch.max(substituted).detach().cpu().numpy())

                plain, substituted = _cv_loss_vectors(envs_out, envs_high_out, high_envs)
                CV_store.append(torch.max(plain).detach().cpu().numpy())
                sbsCV_store.append(torch.max(substituted).detach().cpu().numpy())

            print('CVI:', np.mean(CV_store))
            print('CVII:', np.mean(sbsCV_store))

            # NOTE(review): these lists are re-created for every setting, so the
            # mean/std printed below cover a single run despite the message text.
            final_train_accs = []
            final_test1_accs = []
            final_test2_accs = []

            print('Starting training (iters,weight_grad)=({},{})'.format(iters, penalty_weight))

            mlp = MLP().cuda()
            optimizer = optim.Adam(mlp.parameters(), lr=flags.lr)

            pretty_print('step', 'train nll', 'train acc', 'train penalty', 'w*train penalty', 'test1acc', 'test2 acc')

            for step in range(flags.steps):
                for env in envs:
                    logits_low = mlp.forward_low2(env['images'])
                    env['nll'] = mean_nll(torch.log(logits_low), env['labels'])
                    env['acc'] = mean_accuracy(logits_low, env['labels'])
                for env in high_envs:
                    logits = mlp.forward_all(env['images'])
                    env['penalty'] = _irm_penalty(mlp, logits, env['labels'])
                train_nll = envs[0]['nll']
                train_acc = envs[0]['acc']
                train_penalty = torch.stack(
                    [high_envs[i]['penalty'] for i in range(flags.high_env_number)]).mean()

                loss, penalty_weights = _training_loss(
                    mlp, train_nll, train_penalty, step, iters, penalty_weight)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                test1_acc = envs[1]['acc']
                test2_acc = envs[2]['acc']
                if step % 100 == 0:
                    pretty_print(
                        np.int32(step),
                        train_nll.detach().cpu().numpy(),
                        train_acc.detach().cpu().numpy(),
                        train_penalty.detach().cpu().numpy(),
                        (train_penalty * penalty_weights).detach().cpu().numpy(),
                        test1_acc.detach().cpu().numpy(),
                        test2_acc.detach().cpu().numpy())

            final_train_accs.append(train_acc.detach().cpu().numpy())
            final_test1_accs.append(test1_acc.detach().cpu().numpy())
            final_test2_accs.append(test2_acc.detach().cpu().numpy())
            print('Final train acc (mean/std across restarts so far):')
            print(np.mean(final_train_accs), np.std(final_train_accs))
            print('Final test1 acc (mean/std across restarts so far):')
            print(np.mean(final_test1_accs), np.std(final_test1_accs))
            print('Final test2 acc (mean/std across restarts so far):')
            print(np.mean(final_test2_accs), np.std(final_test2_accs))

            row = [np.mean(sbsCV_store), np.mean(CV_store),
                   np.mean(final_test1_accs), np.mean(final_test2_accs)]
            if restart == 0:
                row = [iters, penalty_weight] + row
            results_per_restart[restart].append(np.array(row))


# Column names: the first restart carries the grid values, later restarts get
# a numeric suffix (identical to the historical 5-restart layout).
columns = ['iters', 'penalty_weight', 'sbstitute_CV', 'simplymax_CV', 'test_acc1', 'test_acc2']
for r in range(1, flags.n_restarts):
    columns += ['sbstitute_CV_{}'.format(r), 'simplymax_CV_{}'.format(r),
                'test_acc1_res{}'.format(r), 'test_acc2_res{}'.format(r)]

x = np.concatenate([np.array(rows) for rows in results_per_restart], axis=1)

sample = pd.DataFrame(x, columns=columns)
print(sample)
sample.to_csv('result_flip_rate={}_high={}.csv'.format(flags.flip_rate, flags.high_env))
